# Copyright (C) 2019-2021 Ruhr West University of Applied Sciences, Bottrop, Germany
# AND Elektronische Fahrwerkssysteme, Gaimersheim, Germany
#
# This Source Code Form is subject to the terms of the Apache License 2.0
# If a copy of the APL2 was not distributed with this
# file, You can obtain one at https://www.apache.org/licenses/LICENSE-2.0.txt.

from typing import List, Dict, Tuple
import json
import numpy as np
from tqdm import tqdm
import torch

try:
    from detectron2.data import DatasetCatalog, MetadataCatalog
    from detectron2.structures.boxes import Boxes, BoxMode, pairwise_iou
except ImportError:
    raise ImportError(
        "Need detectron2 to evaluate object detection calibration. You can get the latest version at https://github.com/facebookresearch/detectron2")


def read_image_ids(filename: str) -> Tuple[List, List]:
    """
    Load the split of image ids used for training and testing from a JSON file.

    Parameters
    ----------
    filename : str
        Path to a JSON file whose root object holds the fields "train" and "test".

    Returns
    -------
    Tuple[List, List]
        Pair ``(train_ids, test_ids)`` taken verbatim from the fields "train" and "test".

    Raises
    ------
    AssertionError
        If the field "train" or "test" is missing in the JSON root ("train" is checked first).
    """

    with open(filename) as handle:
        split = json.load(handle)

    # both fields are mandatory; validate them in a fixed order before reading any of them
    for key, purpose in (("train", "training"), ("test", "testing")):
        assert key in split, "Field '%s' must be given in JSON file root for specifying %s IDs." % (key, purpose)

    return split["train"], split["test"]


def read_json(filename: str, score_threshold: float = None) -> List[Dict]:
    """
    Read JSON prediction file. This file holds predictions in COCO annotation format generated by the Detectron2
    framework (or any other framework that outputs predictions in COCO annotations format with scores attached
    to each prediction). Detections below a certain score_threshold are neglected for training/testing.

    The file must hold a flat list of detections, each with the keys 'image_id', 'category_id', 'bbox' and 'score'.
    Detections are grouped by 'image_id'. Frames are returned in the order in which their first kept detection
    appears in the file. A frame whose detections all fall below the threshold does not appear in the result.

    Parameters
    ----------
    filename : str
        Path to JSON file.
    score_threshold : float, optional, default: None
        Optional score threshold for detections. Detections with a score strictly below this value are dropped.
        If None, all detections are kept.

    Returns
    -------
    List[Dict]
        List of dictionaries where each dict represents a frame with the fields 'image_id', 'category_ids',
        'bboxes' and 'scores'. The last three are NumPy arrays built from the per-detection values in file order.
        The boxes are passed through unchanged; in COCO format they are presumably [x, y, w, h].
    """

    print("Read JSON file %s" % filename)
    with open(filename, "r") as open_file:
        content = json.load(open_file)

    # Detections are "flattened" across all frames, so we sort each detection into its frame.
    # A dict keyed by image id gives O(1) lookup per detection. The former list membership test and the
    # linear frame search made this quadratic in the number of detections. Dicts preserve insertion order,
    # so frames keep the order of their first kept detection.
    frames_by_id = {}
    for prediction in tqdm(content, desc="Preprocess content"):

        # if current prediction is below score threshold, continue
        if score_threshold is not None and prediction['score'] < score_threshold:
            continue

        image_id = prediction['image_id']
        frame = frames_by_id.get(image_id)

        # create new dictionary if current frame has not been processed so far
        if frame is None:
            frame = {'image_id': image_id, 'category_ids': [], 'bboxes': [], 'scores': []}
            frames_by_id[image_id] = frame

        frame['category_ids'].append(prediction['category_id'])
        frame['bboxes'].append(prediction['bbox'])
        frame['scores'].append(prediction['score'])

    frames = list(frames_by_id.values())

    # convert lists to NumPy arrays
    for frame in frames:
        frame['category_ids'] = np.array(frame['category_ids'])
        frame['bboxes'] = np.array(frame['bboxes'])
        frame['scores'] = np.array(frame['scores'])

    return frames


def match_frames_with_groundtruth(frames: List[Dict], dataset: str, ious: List[float]) -> List[Dict]:
    """
    For calibration training and evaluation, we need to assess the precision over all frames. This means that we
    need to match each detection with a ground-truth sample. Mark a matched object with m=1 and a not-matched
    object with m=0. Matched criterion: the IoU with the best-overlapping ground-truth object is greater than or
    equal to the IoU threshold, and the predicted class equals the class of that ground-truth object.

    The frames are modified in place (the fields 'height', 'width' and 'matched' are added) and the same list
    is returned.

    Parameters
    ----------
    frames : List[Dict]
        List of dictionaries containing the neural network predictions (fields 'image_id', 'category_ids',
        'bboxes'). The boxes are interpreted as absolute XYWH boxes.
    dataset : str
        String of the used dataset (see detectron2 registered datasets).
    ious : List[float]
        List with IoU thresholds used for evaluation.

    Returns
    -------
    List[Dict]
        List of dictionaries where each dict represents a frame with certain detections and their according scores,
        classes and bounding boxes. In addition, each frame holds the field 'matched'. This is a list with one
        float32 array of 0/1 flags per IoU threshold (same order as 'ious'), with one flag per prediction. It also
        holds 'height' and 'width' taken from the ground-truth record.

    Raises
    ------
    KeyError
        If a frame's 'image_id' has no record in the registered dataset.
    """

    # get information about current dataset using Detectron2
    dataset_dicts = DatasetCatalog.get(dataset)
    meta = MetadataCatalog.get(dataset)

    # index the ground-truth records by image id once instead of scanning the whole dataset for every frame;
    # 'setdefault' keeps the first record per image id, as a linear first-match search would
    gt_records = {}
    for record in dataset_dicts:
        gt_records.setdefault(record["image_id"], record)

    # if the metadata provides a mapping from dataset category ids to contiguous ids, predicted labels are mapped
    # with it. NOTE(review): this assumes the ground-truth dicts already use contiguous ids (as detectron2's COCO
    # loader does) - confirm for custom datasets
    has_id_mapping = hasattr(meta, "thing_dataset_id_to_contiguous_id")
    id_mapping = meta.thing_dataset_id_to_contiguous_id if has_id_mapping else None

    # iterate over all frames and match with ground-truth
    for frame in tqdm(frames, desc="Match frames with ground-truth annotations"):

        # find according ground-truth record
        d = gt_records.get(frame['image_id'])
        if d is None:
            raise KeyError("No ground-truth record with image_id %r found in dataset '%s'." % (frame['image_id'], dataset))

        frame['height'] = d['height']
        frame['width'] = d['width']

        bboxes = np.asarray(frame['bboxes'])
        num_predictions = len(bboxes)

        # no predictions: nothing to match. The check must run before the box conversion (which cannot handle
        # empty 1-D arrays), and 'matched' is still set so that every frame carries the field
        if num_predictions == 0:
            frame['matched'] = [np.zeros(0, dtype=np.float32) for _ in ious]
            continue

        annotations = d['annotations']

        # no ground-truth objects: no prediction can be matched
        if len(annotations) == 0:
            frame['matched'] = [np.zeros(num_predictions, dtype=np.float32) for _ in ious]
            continue

        # convert both prediction and ground-truth boxes to XYXY format, which is required by the Boxes wrapper
        # that is used to evaluate the IoU scores
        pred_boxes = Boxes(BoxMode.convert(torch.from_numpy(bboxes), from_mode=BoxMode.XYWH_ABS, to_mode=BoxMode.XYXY_ABS))  # N
        gt_boxes = Boxes(torch.from_numpy(np.array(
            [BoxMode.convert(x['bbox'], from_mode=x['bbox_mode'], to_mode=BoxMode.XYXY_ABS) for x in annotations]
        )))  # M

        if has_id_mapping:
            pred_labels = torch.tensor([id_mapping[x] for x in frame['category_ids']])
        else:
            pred_labels = torch.from_numpy(np.asarray(frame['category_ids']))
        gt_categories = torch.from_numpy(np.array([x['category_id'] for x in annotations]))

        # calculate pairwise IoU (N, M) and keep, per prediction, only the best-overlapping ground-truth object
        iou_matrix = pairwise_iou(pred_boxes, gt_boxes)
        iou_scores, idx = torch.max(iou_matrix, dim=1)
        same_label = pred_labels == gt_labels if False else pred_labels == gt_categories[idx]

        # mark predictions as matched that have the correct class label and are above a certain IoU threshold
        frame['matched'] = [((iou_scores >= iou) & same_label).float().numpy() for iou in ious]

    return frames


def save_frames(frames: List[Dict], filename: str):
    """
    Write frames to a JSON file as a flat list of detections (counter-part to 'read_json').

    Every detection of every frame becomes one record with the keys 'image_id', 'category_id', 'bbox' and
    'score'. Records are ordered frame by frame, and within a frame in detection order. All records are built
    before the file is opened for writing.

    Parameters
    ----------
    frames : List[Dict]
        List of dictionaries containing the neural network predictions. Each frame needs 'image_id' plus the
        parallel sequences 'category_ids', 'bboxes' (elements must provide '.tolist()', e.g. NumPy rows) and
        'scores'.
    filename : str
        Path to JSON file.
    """

    serialized = []
    for frame in frames:
        image_id = frame['image_id']
        detections = zip(frame['category_ids'], frame['bboxes'], frame['scores'])

        # cast NumPy scalars/arrays to built-in types so that json can serialize them
        serialized.extend(
            {'image_id': int(image_id), 'category_id': int(category), 'bbox': box.tolist(), 'score': float(confidence)}
            for category, box, confidence in detections
        )

    with open(filename, "w") as handle:
        json.dump(serialized, handle)
